import torch
import torch.nn as nn

from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask


class TransformerDecoderLayerBase(nn.Module):
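    """Shared logic for a single transformer decoder layer.

    Holds the self-attention block (scaled dot-product or average
    attention), the position-wise feed-forward block and the residual
    dropout; subclasses implement ``_forward``.
    """
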
    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
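        """
        Args:
            d_model (int): size of the model (hidden) representations.
            heads (int): number of attention heads.
            d_ff (int): hidden size of the position-wise feed-forward layer.
            dropout (float): dropout on residual connections and feed-forward.
            attention_dropout (float): dropout inside the attention modules.
            self_attn_type (str): "scaled-dot" or "average".
            max_relative_positions (int): relative position representations
                for self-attention (0 disables them).
            aan_useffn (bool): use the feed-forward inside AverageAttention.
            full_context_alignment (bool): recompute attention without the
                future mask when alignments are requested.
            alignment_heads (int): number of heads whose attention is
                averaged for alignment (0 means all heads).
            pos_ffn_activation_fn (ActivationFunction): activation used in
                the position-wise feed-forward layer.
        """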

        super(TransformerDecoderLayerBase, self).__init__()

        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )
        else:
            raise ValueError(
                "self attention type %s not supported" % self_attn_type
            )

        self.feed_forward = PositionwiseFeedForward(
            d_model, d_ff, dropout, pos_ffn_activation_fn
        )
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads


    def forward(self, *args, **kwargs):
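        """Apply the layer, optionally computing alignment attention.

        The keyword argument ``with_align`` (bool, popped here) requests the
        head-averaged attention used for word alignment; all remaining
        arguments are forwarded to ``_forward``.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None): the layer
            output, the attention of the first head, and the averaged
            alignment attention (``None`` when ``with_align`` is False).
        """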

        with_align = kwargs.pop("with_align", False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        
        if with_align:
            
            if self.full_context_alignment:
                # Re-run the layer without the future mask so alignment can
                # use the full target context; overwrite (rather than
                # duplicate) any ``future`` flag already present in kwargs.
                kwargs["future"] = True
                _, attns = self._forward(*args, **kwargs)

            if self.alignment_heads > 0:
                attns = attns[:, : self.alignment_heads, :, :].contiguous()

            attn_align = attns.mean(dim=1)

        return output, top_attn, attn_align

    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout


    def _compute_dec_mask(self, tgt_pad_mask, future):
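        """Combine the target padding mask with a causal (future) mask.

        With ``future=True`` only the padding mask is returned, allowing
        attention over future target positions (used for full-context
        alignment).
        """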

        tgt_len = tgt_pad_mask.size(-1)
        if not future:  
            future_mask = torch.ones(
                [tgt_len, tgt_len],
                device=tgt_pad_mask.device,
                dtype=torch.uint8,
            )

            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)

            try:
                future_mask = future_mask.bool()

            except AttributeError:
                pass
            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)

        else:  
            dec_mask = tgt_pad_mask

        return dec_mask

    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
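        """Dispatch self-attention to the configured attention module."""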

        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )


class TransformerDecoderLayer(TransformerDecoderLayerBase):
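    """Standard transformer decoder layer.

    Runs self-attention over the target, context attention over the encoder
    memory bank, and a position-wise feed-forward network (see ``_forward``).
    """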

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):

        super(TransformerDecoderLayer, self).__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type,
            max_relative_positions,
            aan_useffn,
            full_context_alignment,
            alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )

        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )

        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        super(TransformerDecoderLayer, self).update_dropout(
            dropout, attention_dropout
        )
        self.context_attn.update_dropout(attention_dropout)


    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
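        """Run one decoder layer.

        Args:
            inputs (FloatTensor): ``(batch, tgt_len, model_dim)``
            memory_bank (FloatTensor): ``(batch, src_len, model_dim)``
            src_pad_mask (BoolTensor): ``(batch, 1, src_len)``
            tgt_pad_mask (BoolTensor): ``(batch, 1, tgt_len)``
            layer_cache (dict or None): cached keys/values for step decoding.
            step (int or None): current decoding step.
            future (bool): if True, do not mask future target positions.

        Returns:
            (FloatTensor, FloatTensor): layer output and context attention.
        """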
        dec_mask = None

        if inputs.size(1) > 1:
            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

        inputs_norm = self.layer_norm_1(inputs)

        query, _ = self._forward_self_attn(
            inputs_norm, dec_mask, layer_cache, step
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)

        mid, attns = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )

        output = self.feed_forward(self.drop(mid) + query)

        return output, attns


class TransformerDecoderBase(DecoderBase):
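    """Common state handling and output normalization for transformer
    decoders."""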

    def __init__(self, d_model, copy_attn, alignment_layer):
        
        super(TransformerDecoderBase, self).__init__()
        self.state = {}
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
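        """Alternate constructor from the model options.

        ``embeddings`` is kept in the signature for interface compatibility
        but is not used here: target embeddings are computed outside this
        decoder and passed to ``forward`` as ``tgt_emb``.
        """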
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            opt.attention_dropout[0]
            if isinstance(opt.attention_dropout, list)
            else opt.attention_dropout,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )


    def init_state(self, src, memory_bank, enc_hidden):
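        """Initialize decoder state with the source and an empty step cache."""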
        
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
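        """Apply ``fn`` to the stored source (batch dim 1) and, recursively,
        to every tensor in the step cache."""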

        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():

                if v is not None:

                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state["src"] is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def update_dropout(self, dropout, attention_dropout):
        # No embedding module is owned by this decoder (target embeddings
        # are passed to ``forward`` as ``tgt_emb``), so only the transformer
        # layers need to be updated.
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)


class TransformerDecoder(TransformerDecoderBase):
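    """The transformer decoder from "Attention is All You Need".

    Operates on pre-computed target embeddings (``tgt_emb``) in
    ``(batch, length, dim)`` layout and returns the per-layer hidden states
    in addition to the final output and attentions.
    """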

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )

        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for i in range(num_layers)
            ]
        )

    def detach_state(self):
        self.state["src"] = self.state["src"].detach()


    def forward(self, tgt_emb, memory_bank, src_pad_mask=None,
                tgt_pad_mask=None, step=None, **kwargs):
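        """Decode ``tgt_emb`` attending over ``memory_bank``.

        Args:
            tgt_emb (FloatTensor): ``(batch, tgt_len, model_dim)``
            memory_bank (FloatTensor): ``(batch, src_len, model_dim)``
            src_pad_mask (BoolTensor or None): ``(batch, 1, src_len)``;
                defaults to no padding.
            tgt_pad_mask (BoolTensor or None): ``(batch, 1, tgt_len)``;
                defaults to no padding.
            step (int or None): decoding step; step 0 initializes the cache.

        Returns:
            (FloatTensor, dict, list): final output ``(batch, tgt_len,
            model_dim)``, the attention dict (``std``/``copy``/``align``)
            and the per-layer hidden states.
        """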

        if step == 0:
            self._init_cache(memory_bank)

        batch_size, src_len, _ = memory_bank.size()
        device = memory_bank.device

        if src_pad_mask is None:
            src_pad_mask = torch.zeros(
                (batch_size, 1, src_len), dtype=torch.bool, device=device
            )

        output = tgt_emb
        tgt_len = tgt_emb.size(1)

        if tgt_pad_mask is None:
            tgt_pad_mask = torch.zeros(
                (batch_size, 1, tgt_len), dtype=torch.bool, device=device
            )

        future = kwargs.pop("future", False)
        with_align = kwargs.pop("with_align", False)

        attn_aligns = []
        hiddens = []

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = (
                self.state["cache"]["layer_{}".format(i)]
                if step is not None
                else None
            )
            output, attn, attn_align = layer(
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
                future=future
            )

            hiddens.append(output)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)  # (B, L, D)

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn

        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`

        return output, attns, hiddens

    
    def _init_cache(self, memory_bank):
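        """Create the per-layer cache used for step-wise decoding."""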

        self.state["cache"] = {}
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = {"memory_keys": None, "memory_values": None, "self_keys": None, "self_values": None}
            self.state["cache"]["layer_{}".format(i)] = layer_cache

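
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API: it builds a
    # small TransformerDecoder and runs one forward pass on random
    # embeddings. It assumes the onmt modules imported above are installed
    # and behave as they are used in this file; all sizes are arbitrary.
    decoder = TransformerDecoder(
        num_layers=2,
        d_model=16,
        heads=4,
        d_ff=32,
        copy_attn=False,
        self_attn_type="scaled-dot",
        dropout=0.1,
        attention_dropout=0.1,
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_layer=-1,
        alignment_heads=0,
    )
    tgt_emb = torch.randn(2, 5, 16)      # (batch, tgt_len, d_model)
    memory_bank = torch.randn(2, 7, 16)  # (batch, src_len, d_model)
    out, attns, hiddens = decoder(tgt_emb, memory_bank)
    print(out.shape, attns["std"].shape, len(hiddens))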